import pandas as pd
import geopandas as gpd
import numpy as np
from rasterio.transform import Affine
from pykrige.ok import OrdinaryKriging
from rasterio.features import geometry_mask
from datetime import datetime, timedelta
import sys

# Krige one day's observations of a single variable and average the surface over each boundary polygon
def spatial_join(boundary_gdf, year_gdf, filtered_column, year, month, day, latitude_column, longitude_column,
                 *, grid_size=100, missing_values=(-999, -996)):
    """Krige one day's point observations and average the surface over each boundary polygon.

    The observations are interpolated with ordinary kriging (spherical variogram) onto a
    ``grid_size`` x ``grid_size`` grid spanning their bounding box. The grid cells that fall
    inside each polygon of ``boundary_gdf`` are then averaged.

    Args:
        boundary_gdf: Polygons to summarize. Modified in place: ``MEAN_<filtered_column>``
            and ``Date`` columns are added.
        year_gdf: Point observations for a single day.
        filtered_column: Name of the column holding the variable to interpolate.
        year, month, day: Date used in the log message and written to the ``Date`` column.
        latitude_column, longitude_column: Names of the coordinate columns in ``year_gdf``.
        grid_size: Number of grid nodes along each axis (must be at least 2).
        missing_values: Sentinel values in ``filtered_column`` that mark missing data.

    Returns:
        ``boundary_gdf`` with the two new columns. Every polygon gets NaN when there are no
        usable observations, or when they do not span a two-dimensional extent.

    Raises:
        ValueError: If ``grid_size`` is smaller than 2.
    """
    if grid_size < 2:
        raise ValueError(f'grid_size must be at least 2, got {grid_size}')

    date_label = f'{year}-{month:02d}-{day:02d}'
    print(f'Spatial Join Process Started for {date_label}')
    mean_column = f'MEAN_{filtered_column}'

    # Coerce to numbers. Columns read from CSV may be text, which would defeat the sentinel
    # comparison and make min()/max() lexicographic. Unparseable entries become NaN.
    lon = pd.to_numeric(year_gdf[longitude_column], errors='coerce').to_numpy(dtype=float, na_value=np.nan)
    lat = pd.to_numeric(year_gdf[latitude_column], errors='coerce').to_numpy(dtype=float, na_value=np.nan)
    values = pd.to_numeric(year_gdf[filtered_column], errors='coerce').to_numpy(dtype=float, na_value=np.nan)
    # Only the three columns used for kriging decide whether a row is usable, so NaNs in
    # unrelated columns do not discard valid observations.
    usable = ~(np.isnan(lon) | np.isnan(lat) | np.isnan(values)) & ~np.isin(values, missing_values)
    x_temp = lon[usable]
    y_temp = lat[usable]
    value_temp = values[usable]

    if x_temp.size == 0 or np.ptp(x_temp) == 0 or np.ptp(y_temp) == 0:
        # Either there is nothing to krige, or the grid would have zero width or height,
        # which would make the raster transform below singular.
        boundary_gdf[mean_column] = np.nan
        boundary_gdf['Date'] = date_label
        return boundary_gdf

    # Define the grid over the bounding box of the observations
    grid_longitude = np.linspace(x_temp.min(), x_temp.max(), num=grid_size)
    grid_latitude = np.linspace(y_temp.min(), y_temp.max(), num=grid_size)

    # Perform ordinary kriging
    ok = OrdinaryKriging(
        x_temp, y_temp, value_temp,
        variogram_model="spherical",
        verbose=False,
        enable_plotting=False,
        coordinates_type="euclidean"
    )
    # In grid mode PyKrige returns arrays of shape (len(grid_latitude), len(grid_longitude)),
    # so row 0 corresponds to the lowest latitude.
    z_grid, _sigma_squared = ok.execute("grid", grid_longitude, grid_latitude)
    # Turn any masked cells into NaN so nanmean ignores them.
    z_values = np.ma.filled(np.ma.asarray(z_grid, dtype=float), np.nan)

    # The grid nodes are pixel centres spaced (max - min) / (grid_size - 1) apart, so the raster
    # origin lies half a cell before the first node. A positive y scale keeps row 0 at the
    # lowest latitude, matching the kriging output.
    cell_width = (grid_longitude[-1] - grid_longitude[0]) / (grid_size - 1)
    cell_height = (grid_latitude[-1] - grid_latitude[0]) / (grid_size - 1)
    transform = Affine.translation(grid_longitude[0] - cell_width / 2,
                                   grid_latitude[0] - cell_height / 2) * Affine.scale(cell_width, cell_height)

    # Average the kriged cells inside each boundary polygon
    mean_values = []
    for geom in boundary_gdf.geometry:
        inside = geometry_mask([geom], transform=transform, invert=True, out_shape=z_values.shape)
        cells = z_values[inside]
        mean_values.append(np.nanmean(cells) if cells.size > 0 else np.nan)

    boundary_gdf[mean_column] = mean_values
    boundary_gdf['Date'] = date_label
    return boundary_gdf

# Extract the rows of the CSV data that fall on a single day
def process_csv_file(df, date):
    """Return the rows of ``df`` whose ``Date`` equals ``date``.

    ``Date`` is parsed with ``pd.to_datetime`` on a copy, so the caller's frame is left
    untouched. The returned rows carry the parsed datetime ``Date`` column.

    Args:
        df: Table with a ``Date`` column (strings or datetimes).
        date: Day to select, e.g. a ``datetime`` or ``pd.Timestamp``.

    Returns:
        A new frame containing only the matching rows.
    """
    # assign() returns a new frame instead of overwriting the caller's column.
    dated = df.assign(Date=pd.to_datetime(df['Date']))
    return dated[dated['Date'] == date]

# Write the kriging results to a CSV file and a GeoJSON file
def process_csv_spatial_file(gdf, filename, shp_filename, state_name='Oklahoma'):
    """Write kriging results to a CSV file and a GeoJSON file.

    Both outputs share the same attribute table:
    - TIGER/Line bookkeeping columns are dropped.
    - ``GEOID``/``NAME``/``STATEFP``/``INTPTLAT``/``INTPTLON`` are renamed.
    - A constant ``StateName`` column is added.
    - The identifying columns are moved to the front.

    The GeoJSON additionally carries the original geometry.

    Args:
        gdf: Results with boundary attributes and geometry.
        filename: Path of the CSV file to write.
        shp_filename: Path of the GeoJSON file to write (GeoJSON despite the parameter name).
        state_name: Value written to the ``StateName`` column.
    """
    geometry = gdf.geometry
    unused_columns = ['COUNTYFP', 'COUNTYNS', 'LSAD', 'CLASSFP', 'MTFCC', 'CSAFP', 'CBSAFP', 'METDIVFP',
                      'FUNCSTAT', 'ALAND', 'AWATER', 'NAMELSAD', geometry.name]
    # errors='ignore' accepts boundary files that lack some of these optional columns.
    df = pd.DataFrame(gdf.drop(columns=unused_columns, errors='ignore'))

    # Rename the columns in the dataframe
    df = df.rename(columns={'GEOID': 'CntyFIPS', 'NAME': 'CntyName', 'STATEFP': 'StateFIPS',
                            'INTPTLAT': 'Centroid_Lat', 'INTPTLON': 'Centroid_Lon'})
    df['StateName'] = state_name
    leading_columns = ['StateName', 'StateFIPS', 'CntyName', 'CntyFIPS', 'Centroid_Lat', 'Centroid_Lon']
    remaining_columns = [col for col in df.columns if col not in leading_columns]
    df = df[leading_columns + remaining_columns]
    df.to_csv(filename, index=False)

    # Re-attach the geometry to the cleaned table so the GeoJSON has the same columns as the CSV.
    # Passing the raw geometry array avoids index alignment, because the index may contain
    # duplicates (e.g. when several days of boundaries were concatenated).
    out_gdf = gpd.GeoDataFrame(df, geometry=geometry.values, crs='EPSG:4269')
    out_gdf.to_file(shp_filename, driver='GeoJSON')

# Run kriging for every day that has data and save the accumulated results as CSV and GeoJSON
def environmental_kriging(csv_file_path, boundary_shape_file_path, filtered_column, latitude_column,
                          longitude_column):
    """Krige a variable day by day and summarize it over boundary polygons.

    Reads the point observations and the boundary polygons, then runs ``spatial_join`` once per
    calendar date present in the data. All daily results are written to ``kriging_data.csv``
    and ``kriging_shape_data.geojson`` via ``process_csv_spatial_file``.

    Args:
        csv_file_path: Observation table readable by ``geopandas.read_file``. It must contain
            ``YEAR``, ``MONTH`` and ``DAY`` columns plus the coordinate and target columns.
        boundary_shape_file_path: Polygon file readable by ``geopandas.read_file``.
        filtered_column: Name of the variable to interpolate.
        latitude_column, longitude_column: Names of the coordinate columns.
    """
    env_csv_data = gpd.read_file(csv_file_path)  # Environmental data CSV file
    boundary_data = gpd.read_file(boundary_shape_file_path)  # Boundary shape file

    # Tabular files read through GDAL may deliver fields as text. Coerce the columns that are
    # used numerically so that bad or empty entries become NaN rather than strings.
    for column in ('YEAR', 'MONTH', 'DAY', latitude_column, longitude_column, filtered_column):
        env_csv_data[column] = pd.to_numeric(env_csv_data[column], errors='coerce')

    # Ensure both datasets use the same CRS (EPSG:4269)
    env_csv_data = gpd.GeoDataFrame(env_csv_data, geometry=gpd.points_from_xy(env_csv_data[longitude_column],
                                                                              env_csv_data[latitude_column]),
                                    crs='EPSG:4269')
    boundary_data = boundary_data.to_crs('EPSG:4269')
    env_csv_data = env_csv_data.dropna(subset=['YEAR', 'MONTH', 'DAY'])
    # Integer components keep pd.to_datetime's column assembly unambiguous.
    env_csv_data[['YEAR', 'MONTH', 'DAY']] = env_csv_data[['YEAR', 'MONTH', 'DAY']].astype(int)
    env_csv_data['Date'] = pd.to_datetime(env_csv_data[['YEAR', 'MONTH', 'DAY']])

    # One pass over the data grouped by date, instead of scanning the whole table for every
    # calendar day. Groups come out in chronological order.
    daily_results = []
    for date, daily_data in env_csv_data.groupby('Date'):
        print(f'Processing CSV data for {date.strftime("%Y-%m-%d")}')
        daily_results.append(spatial_join(boundary_data.copy(), daily_data, filtered_column, date.year,
                                          date.month, date.day, latitude_column, longitude_column))

    if not daily_results:
        print('No dated records found; nothing to write.')
        return

    # Concatenate once (repeated concat in the loop is quadratic), then write the outputs once.
    all_results = pd.concat(daily_results)
    process_csv_spatial_file(all_results, 'kriging_data.csv', 'kriging_shape_data.geojson')



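# Command-line entry point: <script> <data_file> <boundary_file> <target_column> <latitude_column> <longitude_column>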
if __name__ == "__main__":
    # Expected positional arguments: data file, boundary file, then the names of the target,
    # latitude and longitude columns. Exit with a usage message instead of an IndexError.
    if len(sys.argv) < 6:
        sys.exit(f'Usage: {sys.argv[0]} <data_file> <boundary_file> <target_column> '
                 '<latitude_column> <longitude_column>')
    file_path = sys.argv[1]
    boundary_shp_file_path = sys.argv[2]
    target_column_name = sys.argv[3]
    latitude_column = sys.argv[4]
    longitude_column = sys.argv[5]
    print('arguments passed to the script are:', sys.argv)
    environmental_kriging(file_path, boundary_shp_file_path, target_column_name, latitude_column, longitude_column)

